{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OneHotEncoder\n",
    "Performs One Hot Encoding.\n",
    "\n",
    "The encoder can select how many different labels per variable to encode into binaries. When top_categories is set to None, all the categories will be transformed into binary variables.\n",
    "\n",
    "However, when top_categories is set to an integer, for example 10, then only the 10 most popular categories will be transformed into binary variables, and the rest will be discarded.\n",
    "\n",
    "The encoder can also either create binary variables for all categories (drop_last = False), or omit the binary variable for the last category (drop_last = True), for use in linear models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from feature_engine.encoding import OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load titanic dataset from OpenML\n",
    "\n",
    "def load_titanic():\n",
    "    data = pd.read_csv('https://www.openml.org/data/get_csv/16826755/phpMYEkMl')\n",
    "    data = data.replace('?', np.nan)\n",
    "    data['cabin'] = data['cabin'].astype(str).str[0]\n",
    "    data['pclass'] = data['pclass'].astype('O')\n",
    "    data['age'] = data['age'].astype('float')\n",
    "    data['fare'] = data['fare'].astype('float')\n",
    "    data['embarked'].fillna('C', inplace=True)\n",
    "    data.drop(labels=['boat', 'body', 'home.dest'], axis=1, inplace=True)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pclass</th>\n",
       "      <th>survived</th>\n",
       "      <th>name</th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>ticket</th>\n",
       "      <th>fare</th>\n",
       "      <th>cabin</th>\n",
       "      <th>embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Allen, Miss. Elisabeth Walton</td>\n",
       "      <td>female</td>\n",
       "      <td>29.0000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>24160</td>\n",
       "      <td>211.3375</td>\n",
       "      <td>B</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Allison, Master. Hudson Trevor</td>\n",
       "      <td>male</td>\n",
       "      <td>0.9167</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Allison, Miss. Helen Loraine</td>\n",
       "      <td>female</td>\n",
       "      <td>2.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Allison, Mr. Hudson Joshua Creighton</td>\n",
       "      <td>male</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Allison, Mrs. Hudson J C (Bessie Waldo Daniels)</td>\n",
       "      <td>female</td>\n",
       "      <td>25.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  pclass  survived                                             name     sex  \\\n",
       "0      1         1                    Allen, Miss. Elisabeth Walton  female   \n",
       "1      1         1                   Allison, Master. Hudson Trevor    male   \n",
       "2      1         0                     Allison, Miss. Helen Loraine  female   \n",
       "3      1         0             Allison, Mr. Hudson Joshua Creighton    male   \n",
       "4      1         0  Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female   \n",
       "\n",
       "       age  sibsp  parch  ticket      fare cabin embarked  \n",
       "0  29.0000      0      0   24160  211.3375     B        S  \n",
       "1   0.9167      1      2  113781  151.5500     C        S  \n",
       "2   2.0000      1      2  113781  151.5500     C        S  \n",
       "3  30.0000      1      2  113781  151.5500     C        S  \n",
       "4  25.0000      1      2  113781  151.5500     C        S  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = load_titanic()\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = data.drop(['survived', 'name', 'ticket'], axis=1)\n",
    "y = data.survived"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "cabin       0\n",
       "pclass      0\n",
       "embarked    0\n",
       "dtype: int64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# we will encode the below variables, they have no missing values\n",
    "X[['cabin', 'pclass', 'embarked']].isnull().sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "cabin       object\n",
       "pclass      object\n",
       "embarked    object\n",
       "dtype: object"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "''' Make sure that the variables are type (object).\n",
    "if not, cast it as object , otherwise the transformer will either send an error (if we pass it as argument) \n",
    "or not pick it up (if we leave variables=None). '''\n",
    "\n",
    "X[['cabin', 'pclass', 'embarked']].dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((916, 8), (393, 8))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# let's separate into training and testing set\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)\n",
    "\n",
    "X_train.shape, X_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One hot encoding consists of replacing the categorical variable with a\n",
    "combination of binary variables which take value 0 or 1, to indicate if\n",
    "a certain category is present in an observation.\n",
    "\n",
    "These binary variables are also known as dummy variables. For\n",
    "example, from the categorical variable \"Gender\" with categories 'female'\n",
    "and 'male', we can generate the boolean variable \"female\", which takes 1\n",
    "if the person is female or 0 otherwise. We can also generate the variable\n",
    "\"male\", which takes 1 if the person is male and 0 otherwise.\n",
    "\n",
    "The encoder has the option to generate one dummy variable per category, or\n",
    "to create dummy variables only for the top n most popular categories, that is,\n",
    "the categories that are shown by the majority of the observations.\n",
    "\n",
    "If dummy variables are created for all the categories of a variable, you have\n",
    "the option to drop one category to avoid creating redundant information. That is,\n",
    "encoding into k-1 variables, where k is the number of unique categories.\n",
    "\n",
    "The encoder will encode only categorical variables (type 'object'). A list\n",
    "of variables can be passed as an argument. If no variables are passed as an\n",
    "argument, the encoder will find and encode categorical variables (object type).\n",
    "\n",
    "\n",
    "#### Note:\n",
    "New categories in the data to transform, that is, those that did not appear\n",
    "in the training set, will be ignored (no binary variable will be created for them).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### All binary, no top_categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OneHotEncoder(variables=['pclass', 'cabin', 'embarked'])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "Parameters\n",
    "----------\n",
    "\n",
    "top_categories: int, default=None\n",
    "    If None, a dummy variable will be created for each category of the variable.\n",
    "    Alternatively, top_categories indicates the number of most frequent categories\n",
    "    to encode. Dummy variables will be created only for those popular categories\n",
    "    and the rest will be ignored. Note that this is equivalent to grouping all the\n",
    "    remaining categories in one group.\n",
    "    \n",
    "variables : list\n",
    "    The list of categorical variables that will be encoded. If None, the  \n",
    "    encoder will find and select all object type variables.\n",
    "    \n",
    "drop_last: boolean, default=False\n",
    "    Only used if top_categories = None. It indicates whether to create dummy\n",
    "    variables for all the categories (k dummies), or if set to True, it will\n",
    "    ignore the last variable of the list (k-1 dummies).\n",
    "'''\n",
    "ohe_enc = OneHotEncoder(top_categories=None,\n",
    "                        variables=['pclass', 'cabin', 'embarked'],\n",
    "                        drop_last=False)\n",
    "ohe_enc.fit(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pclass': array([2, 3, 1], dtype=object),\n",
       " 'cabin': array(['n', 'E', 'C', 'D', 'B', 'A', 'F', 'T', 'G'], dtype=object),\n",
       " 'embarked': array(['S', 'C', 'Q'], dtype=object)}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ohe_enc.encoder_dict_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "      <th>pclass_2</th>\n",
       "      <th>pclass_3</th>\n",
       "      <th>pclass_1</th>\n",
       "      <th>cabin_n</th>\n",
       "      <th>cabin_E</th>\n",
       "      <th>cabin_C</th>\n",
       "      <th>cabin_D</th>\n",
       "      <th>cabin_B</th>\n",
       "      <th>cabin_A</th>\n",
       "      <th>cabin_F</th>\n",
       "      <th>cabin_T</th>\n",
       "      <th>cabin_G</th>\n",
       "      <th>embarked_S</th>\n",
       "      <th>embarked_C</th>\n",
       "      <th>embarked_Q</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>501</th>\n",
       "      <td>female</td>\n",
       "      <td>13.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>19.5000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>588</th>\n",
       "      <td>female</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>23.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>402</th>\n",
       "      <td>female</td>\n",
       "      <td>30.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>13.8583</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1193</th>\n",
       "      <td>male</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>686</th>\n",
       "      <td>female</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         sex   age  sibsp  parch     fare  pclass_2  pclass_3  pclass_1  \\\n",
       "501   female  13.0      0      1  19.5000         1         0         0   \n",
       "588   female   4.0      1      1  23.0000         1         0         0   \n",
       "402   female  30.0      1      0  13.8583         1         0         0   \n",
       "1193    male   NaN      0      0   7.7250         0         1         0   \n",
       "686   female  22.0      0      0   7.7250         0         1         0   \n",
       "\n",
       "      cabin_n  cabin_E  cabin_C  cabin_D  cabin_B  cabin_A  cabin_F  cabin_T  \\\n",
       "501         1        0        0        0        0        0        0        0   \n",
       "588         1        0        0        0        0        0        0        0   \n",
       "402         1        0        0        0        0        0        0        0   \n",
       "1193        1        0        0        0        0        0        0        0   \n",
       "686         1        0        0        0        0        0        0        0   \n",
       "\n",
       "      cabin_G  embarked_S  embarked_C  embarked_Q  \n",
       "501         0           1           0           0  \n",
       "588         0           1           0           0  \n",
       "402         0           0           1           0  \n",
       "1193        0           0           0           1  \n",
       "686         0           0           0           1  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_t = ohe_enc.transform(X_train)\n",
    "test_t = ohe_enc.transform(X_train)\n",
    "\n",
    "test_t.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Selecting top_categories to encode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pclass': [3, 1], 'cabin': ['n', 'C'], 'embarked': ['S', 'C']}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ohe_enc = OneHotEncoder(top_categories=2,\n",
    "                        variables=['pclass', 'cabin', 'embarked'],\n",
    "                        drop_last=False)\n",
    "ohe_enc.fit(X_train)\n",
    "ohe_enc.encoder_dict_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "      <th>pclass_3</th>\n",
       "      <th>pclass_1</th>\n",
       "      <th>cabin_n</th>\n",
       "      <th>cabin_C</th>\n",
       "      <th>embarked_S</th>\n",
       "      <th>embarked_C</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>501</th>\n",
       "      <td>female</td>\n",
       "      <td>13.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>19.5000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>588</th>\n",
       "      <td>female</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>23.0000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>402</th>\n",
       "      <td>female</td>\n",
       "      <td>30.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>13.8583</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1193</th>\n",
       "      <td>male</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>686</th>\n",
       "      <td>female</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         sex   age  sibsp  parch     fare  pclass_3  pclass_1  cabin_n  \\\n",
       "501   female  13.0      0      1  19.5000         0         0        1   \n",
       "588   female   4.0      1      1  23.0000         0         0        1   \n",
       "402   female  30.0      1      0  13.8583         0         0        1   \n",
       "1193    male   NaN      0      0   7.7250         1         0        1   \n",
       "686   female  22.0      0      0   7.7250         1         0        1   \n",
       "\n",
       "      cabin_C  embarked_S  embarked_C  \n",
       "501         0           1           0  \n",
       "588         0           1           0  \n",
       "402         0           0           1  \n",
       "1193        0           0           0  \n",
       "686         0           0           0  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_t = ohe_enc.transform(X_train)\n",
    "test_t = ohe_enc.transform(X_train)\n",
    "test_t.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dropping the last category for linear models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pclass': [2, 3],\n",
       " 'cabin': ['n', 'E', 'C', 'D', 'B', 'A', 'F', 'T'],\n",
       " 'embarked': ['S', 'C']}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ohe_enc = OneHotEncoder(top_categories=None,\n",
    "                        variables=['pclass', 'cabin', 'embarked'],\n",
    "                        drop_last=True)\n",
    "\n",
    "ohe_enc.fit(X_train)\n",
    "ohe_enc.encoder_dict_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "      <th>pclass_2</th>\n",
       "      <th>pclass_3</th>\n",
       "      <th>cabin_n</th>\n",
       "      <th>cabin_E</th>\n",
       "      <th>cabin_C</th>\n",
       "      <th>cabin_D</th>\n",
       "      <th>cabin_B</th>\n",
       "      <th>cabin_A</th>\n",
       "      <th>cabin_F</th>\n",
       "      <th>cabin_T</th>\n",
       "      <th>embarked_S</th>\n",
       "      <th>embarked_C</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>501</th>\n",
       "      <td>female</td>\n",
       "      <td>13.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>19.5000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>588</th>\n",
       "      <td>female</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>23.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>402</th>\n",
       "      <td>female</td>\n",
       "      <td>30.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>13.8583</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1193</th>\n",
       "      <td>male</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>686</th>\n",
       "      <td>female</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         sex   age  sibsp  parch     fare  pclass_2  pclass_3  cabin_n  \\\n",
       "501   female  13.0      0      1  19.5000         1         0        1   \n",
       "588   female   4.0      1      1  23.0000         1         0        1   \n",
       "402   female  30.0      1      0  13.8583         1         0        1   \n",
       "1193    male   NaN      0      0   7.7250         0         1        1   \n",
       "686   female  22.0      0      0   7.7250         0         1        1   \n",
       "\n",
       "      cabin_E  cabin_C  cabin_D  cabin_B  cabin_A  cabin_F  cabin_T  \\\n",
       "501         0        0        0        0        0        0        0   \n",
       "588         0        0        0        0        0        0        0   \n",
       "402         0        0        0        0        0        0        0   \n",
       "1193        0        0        0        0        0        0        0   \n",
       "686         0        0        0        0        0        0        0   \n",
       "\n",
       "      embarked_S  embarked_C  \n",
       "501            1           0  \n",
       "588            1           0  \n",
       "402            0           1  \n",
       "1193           0           0  \n",
       "686            0           0  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_t = ohe_enc.transform(X_train)\n",
    "test_t = ohe_enc.transform(X_train)\n",
    "\n",
    "test_t.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Automatically select categorical variables\n",
    "\n",
    "The encoder selects all the categorical variables automatically if None is passed to the `variables` argument when initialising the encoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OneHotEncoder(drop_last=True, variables=['pclass', 'sex', 'cabin', 'embarked'])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ohe_enc = OneHotEncoder(top_categories=None,\n",
    "                        drop_last=True)\n",
    "\n",
    "ohe_enc.fit(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "      <th>pclass_2</th>\n",
       "      <th>pclass_3</th>\n",
       "      <th>sex_female</th>\n",
       "      <th>cabin_n</th>\n",
       "      <th>cabin_E</th>\n",
       "      <th>cabin_C</th>\n",
       "      <th>cabin_D</th>\n",
       "      <th>cabin_B</th>\n",
       "      <th>cabin_A</th>\n",
       "      <th>cabin_F</th>\n",
       "      <th>cabin_T</th>\n",
       "      <th>embarked_S</th>\n",
       "      <th>embarked_C</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>501</th>\n",
       "      <td>13.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>19.5000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>588</th>\n",
       "      <td>4.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>23.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>402</th>\n",
       "      <td>30.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>13.8583</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1193</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>686</th>\n",
       "      <td>22.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7250</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       age  sibsp  parch     fare  pclass_2  pclass_3  sex_female  cabin_n  \\\n",
       "501   13.0      0      1  19.5000         1         0           1        1   \n",
       "588    4.0      1      1  23.0000         1         0           1        1   \n",
       "402   30.0      1      0  13.8583         1         0           1        1   \n",
       "1193   NaN      0      0   7.7250         0         1           0        1   \n",
       "686   22.0      0      0   7.7250         0         1           1        1   \n",
       "\n",
       "      cabin_E  cabin_C  cabin_D  cabin_B  cabin_A  cabin_F  cabin_T  \\\n",
       "501         0        0        0        0        0        0        0   \n",
       "588         0        0        0        0        0        0        0   \n",
       "402         0        0        0        0        0        0        0   \n",
       "1193        0        0        0        0        0        0        0   \n",
       "686         0        0        0        0        0        0        0   \n",
       "\n",
       "      embarked_S  embarked_C  \n",
       "501            1           0  \n",
       "588            1           0  \n",
       "402            0           1  \n",
       "1193           0           0  \n",
       "686            0           0  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform the train and the test sets with the encoder fitted on X_train\n",
    "train_t = ohe_enc.transform(X_train)\n",
    "test_t = ohe_enc.transform(X_test)\n",
    "\n",
    "test_t.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
